#!/usr/bin/env python
# coding=utf-8
"""
Build a cropped dataset from DIV2K.
This script downloads DIV2K datasets and creates random crops from them.
"""

import argparse
import concurrent.futures
import os
import random
import re
import shutil
import zipfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import requests
from PIL import Image
from tqdm import tqdm

# DIV2K download URLs
_DL_URL = "https://data.vision.ee.ethz.ch/cvl/DIV2K/"

_DL_URLS = {
    "train_hr": _DL_URL + "DIV2K_train_HR.zip",
    "valid_hr": _DL_URL + "DIV2K_valid_HR.zip",
    "train_bicubic_x2": _DL_URL + "DIV2K_train_LR_bicubic_X2.zip",
    "train_unknown_x2": _DL_URL + "DIV2K_train_LR_unknown_X2.zip",
    "valid_bicubic_x2": _DL_URL + "DIV2K_valid_LR_bicubic_X2.zip",
    "valid_unknown_x2": _DL_URL + "DIV2K_valid_LR_unknown_X2.zip",
    "train_bicubic_x3": _DL_URL + "DIV2K_train_LR_bicubic_X3.zip",
    "train_unknown_x3": _DL_URL + "DIV2K_train_LR_unknown_X3.zip",
    "valid_bicubic_x3": _DL_URL + "DIV2K_valid_LR_bicubic_X3.zip",
    "valid_unknown_x3": _DL_URL + "DIV2K_valid_LR_unknown_X3.zip",
    "train_bicubic_x4": _DL_URL + "DIV2K_train_LR_bicubic_X4.zip",
    "train_unknown_x4": _DL_URL + "DIV2K_train_LR_unknown_X4.zip",
    "valid_bicubic_x4": _DL_URL + "DIV2K_valid_LR_bicubic_X4.zip",
    "valid_unknown_x4": _DL_URL + "DIV2K_valid_LR_unknown_X4.zip",
    "train_bicubic_x8": _DL_URL + "DIV2K_train_LR_x8.zip",
    "valid_bicubic_x8": _DL_URL + "DIV2K_valid_LR_x8.zip",
    "train_realistic_mild_x4": _DL_URL + "DIV2K_train_LR_mild.zip",
    "valid_realistic_mild_x4": _DL_URL + "DIV2K_valid_LR_mild.zip",
    "train_realistic_difficult_x4": _DL_URL + "DIV2K_train_LR_difficult.zip",
    "valid_realistic_difficult_x4": _DL_URL + "DIV2K_valid_LR_difficult.zip",
    "train_realistic_wild_x4": _DL_URL + "DIV2K_train_LR_wild.zip",
    "valid_realistic_wild_x4": _DL_URL + "DIV2K_valid_LR_wild.zip",
}

# Dataset folder names
_DATASET_FOLDERS = {
    "hr": "HR",
    "bicubic_x2": "LR_bicubic_X2",
    "bicubic_x3": "LR_bicubic_X3",
    "bicubic_x4": "LR_bicubic_X4",
    "bicubic_x8": "LR_bicubic_X8",
    "unknown_x2": "LR_unknown_X2",
    "unknown_x3": "LR_unknown_X3",
    "unknown_x4": "LR_unknown_X4",
    "realistic_mild_x4": "LR_realistic_mild_X4",
    "realistic_difficult_x4": "LR_realistic_difficult_X4",
    "realistic_wild_x4": "LR_realistic_wild_X4",
}

def download_file(url: str, dest_path: str, timeout: float = 60.0) -> None:
    """Download ``url`` to ``dest_path``, streaming it with a progress bar.

    The payload is first written to ``dest_path + ".part"`` and only renamed
    onto ``dest_path`` once the transfer has finished. An interrupted download
    therefore never leaves a truncated file at ``dest_path`` that a later run
    would mistake for a complete one. A stale ``.part`` file is simply
    overwritten on the next attempt.

    Args:
        url: Source URL.
        dest_path: Final location of the file. If it already exists, nothing
            is downloaded.
        timeout: Passed to ``requests.get`` as its ``timeout`` (seconds).

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    if os.path.exists(dest_path):
        print(f"File already exists: {dest_path}")
        return

    print(f"Downloading {url} to {dest_path}")
    dest_dir = os.path.dirname(dest_path)
    if dest_dir:  # os.makedirs("") raises for a bare file name
        os.makedirs(dest_dir, exist_ok=True)

    tmp_path = dest_path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        total_size = int(r.headers.get('content-length', 0))
        block_size = 8192

        # total=None puts tqdm into "unknown size" mode when the server sends
        # no content-length, instead of showing a bogus 0-byte total.
        with open(tmp_path, 'wb') as f, tqdm(
                total=total_size or None, unit='B', unit_scale=True,
                desc=os.path.basename(dest_path)
            ) as pbar:
            for chunk in r.iter_content(chunk_size=block_size):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)
                    pbar.update(len(chunk))

    # Publish the file only after it has been fully written and closed.
    os.replace(tmp_path, dest_path)

def extract_zip(zip_path: str, extract_to: str) -> None:
    """Unpack every member of the archive ``zip_path`` into ``extract_to``.

    The destination directory is created if it does not exist yet. A tqdm
    progress bar advances once per archive member.
    """
    print(f"Extracting {zip_path} to {extract_to}")
    os.makedirs(extract_to, exist_ok=True)

    label = f"Extracting {os.path.basename(zip_path)}"
    with zipfile.ZipFile(zip_path, 'r') as archive:
        members = archive.infolist()
        for info in tqdm(members, desc=label):
            archive.extract(info, extract_to)

def get_image_paths(dir_path: str) -> List[str]:
    """Recursively collect image files found anywhere below ``dir_path``.

    A file qualifies when its lower-cased name ends in ``.png``, ``.jpg``,
    ``.jpeg`` or ``.bmp``. The full paths are returned in sorted order. A
    missing directory yields an empty list.
    """
    suffixes = ('.png', '.jpg', '.jpeg', '.bmp')
    found = [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(dir_path)
        for name in names
        if name.lower().endswith(suffixes)
    ]
    found.sort()
    return found

def random_crop(
    image_path: str,
    crop_size: int,
    rng: Optional[random.Random] = None,
) -> np.ndarray:
    """Return a random ``crop_size`` x ``crop_size`` window of an image.

    If the image is smaller than ``crop_size`` in either dimension, it is first
    upscaled with Lanczos resampling (with a 10% margin), so that both sides
    are at least ``crop_size``.

    Args:
        image_path: Path of the image file to crop.
        crop_size: Side length of the square crop in pixels; must be positive.
        rng: Generator used to pick the crop position. Defaults to the shared
            module-level ``random`` functions.

    Returns:
        ``np.array`` of the cropped PIL image.

    Raises:
        ValueError: If ``crop_size`` is not positive.
    """
    if crop_size <= 0:
        raise ValueError(f"crop_size must be positive, got {crop_size}")
    randint = random.randint if rng is None else rng.randint

    # Context manager releases the file handle as soon as the crop exists;
    # this runs once per task, for many files, inside a thread pool.
    with Image.open(image_path) as opened:
        image = opened
        width, height = image.size

        if width < crop_size or height < crop_size:
            scale = max(crop_size / width, crop_size / height) * 1.1  # 10% margin
            image = image.resize((int(width * scale), int(height * scale)), Image.LANCZOS)
            width, height = image.size

        left = randint(0, width - crop_size)
        top = randint(0, height - crop_size)
        return np.array(image.crop((left, top, left + crop_size, top + crop_size)))


def process_image(args: Tuple[object, ...]) -> None:
    """Create one random crop and save it as an image file.

    Args:
        args: ``(image_path, save_path, crop_size, image_id)``, optionally
            followed by an integer seed. With a seed, the crop position comes
            from a private ``random.Random(seed)`` rather than the shared
            module-level generator. This keeps the result independent of the
            order in which worker threads run. ``image_id`` is not used here.
    """
    image_path, save_path, crop_size, _image_id, *extra = args
    rng = random.Random(extra[0]) if extra else None

    crop = random_crop(image_path, crop_size, rng)

    save_dir = os.path.dirname(save_path)
    if save_dir:  # os.makedirs("") raises for a bare file name
        os.makedirs(save_dir, exist_ok=True)
    Image.fromarray(crop).save(save_path)


def create_random_crops(
    image_paths: List[str],
    output_dir: str,
    crop_size: int,
    num_crops: int,
    split: str
) -> None:
    """Write ``num_crops`` random crops of ``image_paths`` into ``output_dir``.

    Source images are consumed in shuffled passes. Each pass samples without
    replacement, so every image is used once before any image is reused.
    Crops are saved as ``00000000.png``, ``00000001.png``, ... by a thread
    pool. Each task carries a seed that is drawn here, in the calling thread,
    from the module-level ``random`` generator. The output therefore depends
    only on that generator's state (e.g. ``random.seed``), not on thread
    scheduling.

    Args:
        image_paths: Candidate source images.
        output_dir: Directory receiving the numbered PNG crops.
        crop_size: Side length of each square crop in pixels.
        num_crops: Number of crops to write.
        split: Label used in the progress bar and error messages.

    Raises:
        ValueError: If crops are requested but ``image_paths`` is empty. The
            sampling loop could otherwise never make progress.
    """
    if num_crops > 0 and not image_paths:
        raise ValueError(f"Cannot create {split} crops: no source images found")
    os.makedirs(output_dir, exist_ok=True)

    tasks = []
    image_id = 0
    while image_id < num_crops:
        # One shuffled pass over the images, truncated to what is still needed.
        batch = random.sample(image_paths, min(len(image_paths), num_crops - image_id))
        for img_path in batch:
            save_path = os.path.join(output_dir, f"{image_id:08d}.png")
            tasks.append((img_path, save_path, crop_size, image_id, random.getrandbits(32)))
            image_id += 1

    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Draining the iterator makes worker exceptions surface here.
        list(tqdm(
            executor.map(process_image, tasks),
            total=len(tasks),
            desc=f"Creating {split} crops"
        ))

def _download_and_extract(url: str, download_dir: str, extract_dir: str) -> None:
    """Ensure the archive at ``url`` has been unpacked into ``extract_dir``.

    Does nothing when ``extract_dir`` already exists. The ZIP may therefore be
    deleted after a successful extraction without triggering a re-download.
    Extraction goes into a sibling ``.partial`` directory that is renamed into
    place only when it has completed. An interrupted extraction is redone on
    the next run instead of being treated as finished.
    """
    if os.path.exists(extract_dir):
        return

    zip_path = os.path.join(download_dir, os.path.basename(url))
    if not os.path.exists(zip_path):
        download_file(url, zip_path)

    partial_dir = extract_dir + ".partial"
    if os.path.exists(partial_dir):
        shutil.rmtree(partial_dir)  # leftover from an interrupted extraction
    extract_zip(zip_path, partial_dir)
    os.replace(partial_dir, extract_dir)


def _fetch_split_images(dataset_type: str, download_dir: str, extract_dir: str) -> Optional[List[str]]:
    """Return the image paths of ``dataset_type``'s archive, fetching it if needed.

    Returns ``None``, after printing a message, when ``_DL_URLS`` lists no
    archive for ``dataset_type``.
    """
    url = _DL_URLS.get(dataset_type)
    if not url:
        print(f"No download URL for dataset: {dataset_type}")
        return None
    _download_and_extract(url, download_dir, extract_dir)
    # get_image_paths walks subdirectories, so the archive's top-level folder
    # (e.g. DIV2K_train_HR/) needs no special handling.
    return get_image_paths(extract_dir)


def process_dataset(
    dataset_type: str,
    output_base_dir: str,
    crop_size: int,
    train_crops: int,
    valid_crops: int,
    test_crops: int,
    download_dir: str
) -> None:
    """Download, extract and crop one split of one DIV2K variant.

    Args:
        dataset_type: ``"train_<variant>"`` or ``"valid_<variant>"``, where
            ``<variant>`` is a key of ``_DATASET_FOLDERS`` (e.g. ``"train_hr"``,
            ``"valid_bicubic_x2"``). Unknown variants are reported and skipped.
        output_base_dir: Root of the output tree. Crops are written under
            ``<output_base_dir>/DIV2K_<crop_size>/<folder>/{train,valid,test}``.
        crop_size: Side length of each square crop in pixels.
        train_crops: Crops written for a ``train_`` dataset type.
        valid_crops: Crops written to ``valid/`` for a ``valid_`` dataset type.
        test_crops: Crops written to ``test/`` for a ``valid_`` dataset type.
        download_dir: Where archives are downloaded and extracted.
    """
    print(f"\nProcessing dataset: {dataset_type}")

    folder_name = _DATASET_FOLDERS.get(dataset_type.replace("train_", "").replace("valid_", ""))
    if not folder_name:
        print(f"Unknown dataset type: {dataset_type}")
        return

    output_dir = os.path.join(output_base_dir, f"DIV2K_{crop_size}", folder_name)
    os.makedirs(output_dir, exist_ok=True)

    # Extraction cache directories keep their established names, so archives
    # already extracted there are reused:
    #   train: <download_dir>/<dataset_type>
    #   valid: <download_dir>/<dataset_type>_valid
    if dataset_type.startswith("train_"):
        train_images = _fetch_split_images(
            dataset_type, download_dir, os.path.join(download_dir, dataset_type)
        )
        if train_images is None:
            return
        print(f"Found {len(train_images)} training images")

        if train_crops > 0:
            train_output_dir = os.path.join(output_dir, "train")
            create_random_crops(train_images, train_output_dir, crop_size, train_crops, "train")

    elif dataset_type.startswith("valid_"):
        valid_images = _fetch_split_images(
            dataset_type, download_dir, os.path.join(download_dir, f"{dataset_type}_valid")
        )
        if valid_images is None:
            return
        print(f"Found {len(valid_images)} validation images")

        if valid_crops > 0:
            valid_output_dir = os.path.join(output_dir, "valid")
            create_random_crops(valid_images, valid_output_dir, crop_size, valid_crops, "valid")

        # No test archive is listed in _DL_URLS, so test crops are also drawn
        # from the validation images. The valid and test sets can therefore
        # contain crops of the same source images.
        if test_crops > 0:
            test_output_dir = os.path.join(output_dir, "test")
            create_random_crops(valid_images, test_output_dir, crop_size, test_crops, "test")

def main():
    """Command-line entry point: parse options, then crop each requested DIV2K variant.

    For every variant, the ``train_`` split is processed first. The ``valid_``
    split follows; it produces both the validation and the test crops.
    """
    default_datasets = ["hr", "bicubic_x2", "bicubic_x3", "bicubic_x4"]

    parser = argparse.ArgumentParser(description="Build cropped dataset from DIV2K")
    parser.add_argument("--output_dir", type=str, default="./DIV2K_cropped",
                        help="Base output directory")
    parser.add_argument("--download_dir", type=str, default="./DIV2K_downloads",
                        help="Directory to store downloaded files")
    parser.add_argument("--crop_size", type=int, default=128,
                        help="Size of the crops (n×n)")
    parser.add_argument("--train_crops", type=int, default=10000,
                        help="Number of training crops to generate")
    parser.add_argument("--valid_crops", type=int, default=2000,
                        help="Number of validation crops to generate")
    parser.add_argument("--test_crops", type=int, default=2000,
                        help="Number of test crops to generate")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility")
    # choices rejects typos at parse time instead of printing
    # "Unknown dataset type" later for every split.
    parser.add_argument("--datasets", type=str, nargs="+",
                        choices=sorted(_DATASET_FOLDERS),
                        help="Specific datasets to process "
                             f"(default: {' '.join(default_datasets)})")

    args = parser.parse_args()

    # Reject nonsensical sizes and counts up front with a usage error.
    if args.crop_size <= 0:
        parser.error("--crop_size must be a positive integer")
    for option in ("train_crops", "valid_crops", "test_crops"):
        if getattr(args, option) < 0:
            parser.error(f"--{option} must not be negative")

    # Seed the module-level RNGs used for image sampling and crop placement.
    random.seed(args.seed)
    np.random.seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.download_dir, exist_ok=True)

    datasets_to_process = args.datasets if args.datasets else default_datasets

    for dataset_type in datasets_to_process:
        for split in ("train", "valid"):
            process_dataset(
                f"{split}_{dataset_type}",
                args.output_dir,
                args.crop_size,
                args.train_crops,
                args.valid_crops,
                args.test_crops,
                args.download_dir
            )

    print(f"\nDone! Output directory: {args.output_dir}")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()